#pragma once

#include <ATen/core/builtin_function.h>
#include <ATen/core/stack.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_interface.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/custom_class.h>

namespace torch {
namespace jit {

// Static registration API for backends.
//
// Constructing a `backend<TBackendInterface>` registers TBackendInterface
// through torch::class_ under the `__backends__` namespace, exposing its
// `preprocess`, `compile` and `execute` methods. generateToBackendFn() then
// builds the function that lowers a Module to that backend.
template <class TBackendInterface>
class backend {
  static_assert(
      std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value,
      "torch::jit::backend<T> requires T to inherit from PyTorchBackendInterface");
  // Namespace passed to torch::class_ for every registered backend class.
  constexpr static auto kBackendsNamespace = "__backends__";
  // Name given at construction; used to build the qualified class name in
  // generateToBackendFn().
  std::string backend_name_;

 public:
  // Registers TBackendInterface under kBackendsNamespace with the given name.
  //
  // NOTE: `cls` is a function-local static, so the registration expression is
  // evaluated only the first time this constructor runs for a given
  // TBackendInterface; later constructions only set backend_name_.
  explicit backend(const std::string& name) : backend_name_(name) {
    static auto cls = torch::class_<TBackendInterface>(kBackendsNamespace, name)
                          .def(torch::init<>())
                          ._def_unboxed(
                              "preprocess",
                              detail::getPreprocessFunc<TBackendInterface>(),
                              detail::getPreprocessSchema())
                          ._def_unboxed(
                              "compile",
                              detail::getCompileFunc<TBackendInterface>(),
                              detail::getCompileSchema())
                          ._def_unboxed(
                              "execute",
                              detail::getExecuteFunc<TBackendInterface>(),
                              detail::getExecuteSchema());
  }

  // Generates and returns a function that takes a Module and a lowering
  // specification in the form of a dictionary. The caller is responsible for
  // binding this into a CPython module.
  //
  // The returned function clones the input Module, builds a new
  // "LoweredModule" holding the processed module, the compile spec, a backend
  // instance and the compiled handles, defines one forwarding method per key
  // of the compile spec, and finally runs __preprocess and __setstate__ so
  // that the returned Module is ready to execute.
  std::function<Module(Module, py::dict)> generateToBackendFn() {
    const c10::QualifiedName qual_backend_name(
        {"__torch__", "torch", "classes", kBackendsNamespace, backend_name_});
    const std::string backend_name = qual_backend_name.name();

    // Capture exactly the two locals that are needed, by value; nothing from
    // `this` is used, so the returned function does not depend on this object.
    return [qual_backend_name, backend_name](
               Module orig_module, py::dict method_compile_spec) {
      // TODO: Validate method_compile_spec.

      // Clone orig_module to make sure backend transformation is
      // functional.
      auto cloned_module = orig_module.clone();

      // Represents a Type of Dict[str, Any].
      auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());

      // Generate LoweredModule.
      Module loweredModule(
          "torch.jit." + backend_name + "LoweredModule",
          get_python_cu(),
          /*should_mangle=*/true);

      // Generate attributes.
      // This is the original cloned and preprocessed module.
      loweredModule.register_attribute(
          "__processed_module",
          AnyType::get(),
          cloned_module._ivalue(),
          /*is_param=*/false);

      // This is for the method_compile_spec passed in to to_<backend> or
      // loaded from an exported model.
      loweredModule.register_attribute(
          "__method_compile_spec",
          any_dict_ty,
          toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
          /*is_param=*/false);

      // This is a pointer to a backend instance that is used to access
      // compile and execute functions. It is registered with a
      // default-constructed holder here; __create_backend (defined below)
      // assigns the real instance.
      auto cls = getCustomClass(qual_backend_name.qualifiedName());
      TORCH_INTERNAL_ASSERT(cls);
      c10::intrusive_ptr<torch::CustomClassHolder> empty_backend;
      loweredModule.register_attribute(
          "__backend", cls, IValue::make_capsule(empty_backend));

      // This is the list of opaque backend handles returned by
      // backend.compile.
      loweredModule.register_attribute(
          "__handles",
          any_dict_ty,
          c10::impl::GenericDict(
              any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
          /*is_param=*/false);

      // Methods.

      // This is a helper function for creating a new instance of the
      // backend class.
      static const auto create_backend_ct = CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
      TemplateEnv create_backend_te;
      create_backend_te.s("name", qual_backend_name.qualifiedName());
      loweredModule.define(
          create_backend_ct.format(create_backend_te), loweredModuleResolver());

      // getstate and setstate are for serialization/deserialization of the
      // LoweredModule.
      loweredModule.define(
          R"(
            def __getstate__(self):
                return self.__method_compile_spec, self.__processed_module
            )",
          loweredModuleResolver());

      loweredModule.define(
          R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                self.__create_backend()
                self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
            )",
          loweredModuleResolver());

      // This is never called during compilation or execution, but is needed
      // to generate the LoweredModule because we don't have access to an
      // instance of the backend as a C++ object with which to call
      // preprocess.
      loweredModule.define(
          R"(
            def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
                self.__create_backend()
                self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
          )",
          loweredModuleResolver());

      // This loop generates one method on the LoweredModule for every key
      // in method_compile_spec.
      for (const auto& spec_entry : method_compile_spec) {
        std::string method_name = py::cast<std::string>(spec_entry.first);
        static const auto method_ct = CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                ${refine,}
                return $ret
            )");

        TemplateEnv method_te;
        method_te.s("method", method_name);
        auto method = orig_module.get_method(method_name);
        auto& function = method.function();
        // Only read below, so bind by reference instead of copying.
        const auto& schema = function.getSchema();

        // Generate the inputs for the function signature (def_inputs) and
        // for passing to backend.execute (fwd_inputs).
        std::vector<std::string> def_inputs, fwd_inputs;
        for (const auto& arg : schema.arguments()) {
          const auto& name = arg.name();

          // Skip self since that is only and always present in the
          // signature.
          if (name == "self") {
            continue;
          }

          if (arg.kwarg_only()) {
            // If this is a kwarg, it needs to be emitted as keyword=value
            // in the definition and keyword=keyword in the call to
            // backend_execute.
            // NOTE(review): fwd_inputs is spliced into the list literal
            // `[${fwd_inputs,}]` in method_ct; a `name=name` entry inside a
            // list literal looks suspicious - confirm kwarg-only arguments
            // are handled as intended.
            const auto& default_value = arg.default_value();
            TORCH_INTERNAL_ASSERT(default_value.has_value());
            std::stringstream def_ss;
            def_ss << name << "=";
            // The formatter returns false for every value, i.e. it never
            // reports having handled the formatting itself.
            default_value->repr(
                def_ss,
                [](std::ostream&, const IValue&) -> bool { return false; });
            def_inputs.emplace_back(def_ss.str());
            fwd_inputs.emplace_back(name + "=" + name);
          } else {
            // If this is not a kwarg, it should be emitted as is in the
            // signature and the call to backend_execute.
            def_inputs.emplace_back(name);
            fwd_inputs.emplace_back(name);
          }
        }

        // Generate a comma-delimited list of identifiers to unpack outputs, as
        // well as a list of isinstance checks to make sure the backend returned
        // the types it was supposed to.
        std::string unpacked_outputs;
        std::vector<std::string> type_checks;
        TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
        auto out_ty = schema.returns().at(0).type();

        // Appends the identifier `_<index>` to the unpack list and records the
        // matching `assert isinstance(_<index>, <type>)` check.
        auto add_output = [&unpacked_outputs, &type_checks](
                              size_t index, const std::string& type_str) {
          const std::string ident = "_" + std::to_string(index);
          if (index > 0) {
            unpacked_outputs += ", ";
          }
          unpacked_outputs += ident;
          type_checks.emplace_back(
              "assert isinstance(" + ident + ", " + type_str + ")");
        };

        if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
          const auto tuple_elements = out_tuple_ty->elements();
          // The generated method always unpacks at least `_0`, so an empty
          // tuple return type cannot be expressed; fail loudly instead of
          // indexing past the end of the element list.
          TORCH_INTERNAL_ASSERT(
              !tuple_elements.empty(),
              "Cannot lower a method whose return type is an empty tuple");
          for (size_t i = 0, n = tuple_elements.size(); i < n; ++i) {
            add_output(i, tuple_elements[i]->str());
          }
        } else {
          add_output(0, out_ty->str());
        }

        method_te.v("def_inputs", def_inputs);
        method_te.v("fwd_inputs", fwd_inputs);
        method_te.v("refine", type_checks);
        method_te.s("ret", unpacked_outputs);

        loweredModule.define(
            method_ct.format(method_te), loweredModuleResolver());
      }

      // Run preprocess so that __processed_module is set correctly before
      // compilation.
      loweredModule.run_method(
          "__preprocess",
          cloned_module._ivalue(),
          toIValue(method_compile_spec, any_dict_ty).toGenericDict());

      // Call __setstate__ to ensure that the returned Module is ready to
      // run.
      auto state = at::ivalue::Tuple::create(
          toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
          loweredModule.attr("__processed_module"));
      loweredModule.run_method("__setstate__", state);
      return loweredModule;
    };
  }
};

} // namespace jit
} // namespace torch
